#--------------------------------------------------#
#   AFI Four-item construction                     #
#--------------------------------------------------#

#### Preliminaries ####

# Report which R version produced this run (auto-printed at top level).
R.version[["version.string"]]

# Start from an empty global environment so no stale objects leak in.
rm(list = ls())

# load packages
library(here)
library(tidyverse)
library(rworldmap) # maps
library(magrittr)
library(reshape2)
library(ggpubr)
library(leaflet)
library(plotly)
library(gapminder)
library(countrycode)
library(ggrepel)
library(ggthemes)
library(lubridate)
library(dotwhisker)


library(coda)
library(MASS)
library(rjags)
library(parallel)
library(magrittr)
library(dplyr)
library(tools)
library(devtools)

## Install local v-dem tools package
devtools::install_local("vutils")
library(vutils)

# Input/output locations for the BFA pipeline (relative to the working dir).
INDIR <- "data/bfa/input"              # z.sample CSVs read by the model
OUTDIR <- "results"                    # root for every model output
UTABLE <- "data/bfa/country_unit.rds"  # country-unit reference table

dir.create(OUTDIR, showWarnings = FALSE, recursive = TRUE)

# Fail fast when expected inputs/outputs are not in place.
stopifnot(
  dir.exists(INDIR),
  dir.exists(OUTDIR),
  file.exists(UTABLE)
)

# Root of the data tree used when preparing the BFA input files.
data_storage <- "data"


#### Data Management BFA Input ####

# Posterior draws for the civic & academic space indicators, one row per
# variable x country-date.
civic_and_academic_space_input <- read.csv(file.path(data_storage, "posteriors/v13/civic_and_academic_space/mm_civic_and_academic_space_variables.csv"))

VARS <- c("v2cafres", "v2cafexch", "v2cainsaut", "v2casurv", "v2clacfree")

civic_and_academic_space_input_subset <- civic_and_academic_space_input %>%
  filter(var_name %in% c(VARS))

# Reshape one indicator into the BFA input layout: a `name` column holding
# "<country_text_id> <historical_date>" followed by the draw columns (names
# starting with "V"; note starts_with() ignores case by default). Rows whose
# name contains "A_" are dropped.
# `.env$varname` keeps a data column of the same name from masking the arg.
make_bfa_input <- function(df, varname) {
  df %>%
    filter(var_name == .env$varname) %>%
    mutate(name = paste(country_text_id, historical_date)) %>%
    dplyr::select(-c(var_name, country_text_id, historical_date)) %>%
    dplyr::select(name, starts_with("V")) %>%
    filter(!str_detect(name, "A_"))
}

# Only the four AFI components are exported for the BFA.
# NOTE(review): v2clacfree is kept in the subset above but never written out.
bfa_input_vars <- c("v2cafres", "v2cafexch", "v2cainsaut", "v2casurv")
bfa_inputs <- lapply(setNames(bfa_input_vars, bfa_input_vars), make_bfa_input,
                     df = civic_and_academic_space_input_subset)

# Keep the per-variable objects available under their original names.
v2cafres_input <- bfa_inputs[["v2cafres"]]
v2cafexch_input <- bfa_inputs[["v2cafexch"]]
v2cainsaut_input <- bfa_inputs[["v2cainsaut"]]
v2casurv_input <- bfa_inputs[["v2casurv"]]

# One "<var>.10000.Z.sample.csv" per variable, in the BFA input directory.
for (bfa_var in bfa_input_vars) {
  write.csv(bfa_inputs[[bfa_var]],
            file.path(data_storage, "bfa/input", paste0(bfa_var, ".10000.Z.sample.csv")),
            row.names = FALSE)
}

##### Academic Freedom Index (Four-Item) #####

# Model settings for the four-item Academic Freedom Index.
SAVE.NAME <- "academic_freedom_index"
VARS <- c("v2cafres", "v2cafexch", "v2cainsaut", "v2casurv")

# ITER: number of separate JAGS runs (each uses one input draw column);
# BURNIN / MCMC: burn-in and sampling iterations per run; THIN: thinning.
ITER <- 200
BURNIN <- 10000
MCMC <- 10000
THIN <- 200

info(sprintf("%d runs with %d sampling iterations, %d burnin, and %d thin",
             ITER, MCMC, BURNIN, THIN))

# Reference table of country-units; its "<country> <date>" keys are later
# unioned with the modelled country-dates to form the full output index.
utable <- readRDS(UTABLE)
utable_names <- with(utable, paste(country_text_id, historical_date))

# Coverage gaps: per country, find each year followed by a jump of more
# than one year, and record Dec 31 of the first missing year after it.
# NOTE(review): the intent was to insert these as NA before stretching so
# gaps are not filled, but `gap_dates` is not referenced later in this script.
ll <- Filter(
  Negate(is.null),
  lapply(with(utable, split(year, country_text_id)), function(years) {
    years <- sort(years)
    last_before_gap <- years[c(diff(years) > 1, FALSE)]
    if (length(last_before_gap) > 0) {
      as.Date((last_before_gap + 1) %^% "-12-31")
    }
  })
)

gap_dates <- unlist(lapply(names(ll), function(country) {
  paste(country, ll[[country]])
}))


# Load each variable's z.sample matrix, label every column with the variable
# name, drop vignette rows and keep ITER evenly spaced draw columns (column i
# feeds JAGS run i below).
vars.ll <- lapply(setNames(VARS, VARS), function(varname) {
  m <- locate_z.sample(INDIR, varname) %>% read_matrix
  colnames(m) <- rep(varname, ncol(m))

  draw_idx <- seq(1, by = ncol(m) / ITER, length.out = ITER)

  # If we have any vignettes, chop them off! drop = FALSE keeps a matrix
  # (and its row names) even when only one country-date survives.
  m[!is_vignette(rownames(m)), draw_idx, drop = FALSE]
})

###
# Conform all matrices to one row set: the union of every variable's
# country-dates, in order of first appearance.
cvar_names <- Reduce(union, lapply(vars.ll, rownames))
vars.ll <- lapply(vars.ll, vutils::stretch, cvar_names, gaps = FALSE, preserve.na = TRUE)

# A country-date is missing for a variable when all of its draws are NA.
# Country-dates missing in more than 50% of the input variables are removed
# here; they are re-added as empty columns after the model runs so that we
# don't stretch over those dates.
missing.ma <- do.call(rbind, lapply(vars.ll, function(ma) {
  rowSums(is.na(ma)) == ncol(ma)
}))

dropped_dates <- colnames(missing.ma)[colMeans(missing.ma) > .5]

if (length(dropped_dates) > 0) {
  info(sprintf("Found %d country-dates with >50%% missingness",
               length(dropped_dates)))
}

vars.ll <- lapply(vars.ll, function(m) {
  m[!(rownames(m) %in% dropped_dates), , drop = FALSE]
})
cvar_names <- setdiff(cvar_names, dropped_dates)

stopifnot(do.call(all_identical, lapply(vars.ll, rownames)))

###
# Initial values. n_obs = total output length of `xi`, i.e. total
# number of unique country-dates.
n_pars <- length(VARS)
n_obs <- length(cvar_names)

sprintf("Found %d total obs", n_obs) %>% info

# The ITER runs are split into N_CHAINS consecutive, equal-sized groups.
# Runs in one group share initial values and are later pooled into one
# pseudo "chain" for the Gelman & Rubin diagnostic, so both steps must use
# the same grouping. It is derived from ITER rather than hard-coded.
N_CHAINS <- 4
stopifnot(ITER %% N_CHAINS == 0)  # pseudo-chains must have equal length
chain_of_run <- ceiling(seq_len(ITER) / (ITER / N_CHAINS))

inits.ll <- lapply(seq_len(N_CHAINS), function(i) {
  list(gamma = mvrnorm(n_pars, c(0, 0), diag(.01, 2)),
       tau = rgamma(n_pars, .01, .01),
       xi = rnorm(n_obs, 0, 1))
})


posteriors <- mclapply(seq_len(ITER), function(i) {
  info("Running model " %^% i)
  # Draw column i of every variable, side by side (rows = country-dates).
  full.ma <- do.call(cbind, lapply(vars.ll, function(ma) ma[, i]))

  # Normalize prior to running the model
  full.ma <- scale(full.ma)

  input.data <- list(n = nrow(full.ma), # n_obs
                     p = ncol(full.ma), # n_pars
                     y = full.ma)

  # Every run in the same group starts from the same initial values
  inits <- inits.ll[[chain_of_run[i]]]

  model <- jags.model(file = "bfa.jag", data = input.data,
                      inits = inits, quiet = TRUE)

  update(model, BURNIN)
  mcmc <- coda.samples(model, c("xi", "gamma", "omega"), MCMC, THIN)

  # Trust me, life is better when the country-dates are saved
  # together with the posterior object.
  b <- grepl("xi", colnames(mcmc[[1]]), perl = TRUE)
  colnames(mcmc[[1]])[b] <- rownames(full.ma) %^% "_" %^% colnames(mcmc[[1]])[b]

  as.mcmc(mcmc)
}, mc.cores = 1)
mc_assert(posteriors)

info("JAGS models finished")

###
# Check for convergence for each parameter by pooling the runs into the
# N_CHAINS pseudo "chains" (the groups that share initial values) and then
# running the Gelman & Rubin diagnostic.
gelman <- split(posteriors, chain_of_run) %>%
  lapply(function(ll) do.call(rbind, ll) %>% as.mcmc) %>%
  mcmc.list %>%
  gelman.diag(autoburnin = FALSE, multivariate = FALSE)

# A parameter family fails when more than 5% of its point-estimate PSRFs
# exceed 1.1. A non-finite PSRF (NA/NaN/Inf) also counts as a failure, so the
# workspace is still saved instead of `if` erroring on NA.
g <- gelman$psrf
for (p in c("xi", "gamma", "omega")) {
  psrf <- g[grepl(p, rownames(g)), 1]
  share_bad <- mean(!is.finite(psrf) | psrf > 1.1)
  if (share_bad > .05) {
    # Save nonconverged posterior objects separately
    file.path(OUTDIR, "nonconverged") %>% dir.create(recursive = TRUE, showWarnings = FALSE)
    file.path(OUTDIR, "nonconverged", sprintf("%s_failed.RData", SAVE.NAME)) %>% save.image

    stop("Convergence check failed for " %^% p)
  }
}

# Persist the raw per-run posterior objects.
dir.create(file.path(OUTDIR, "mcmc_posteriors"), recursive = TRUE, showWarnings = FALSE)
write_file(posteriors, file.path(OUTDIR, "mcmc_posteriors", sprintf("post_%s.rds", SAVE.NAME)))

###
# Extract `xi`, our latent factor, from every run and stack the runs'
# draws row-wise. Columns are renamed from "<country-date>_xi[k]" to the
# bare "<country-date>" key.
combined.posterior <- do.call(rbind, lapply(posteriors, function(run) {
  draws <- as.matrix(run)
  xi_draws <- draws[, grepl("xi", colnames(draws), perl = TRUE)]
  colnames(xi_draws) <- sub("_xi.*$", "", colnames(xi_draws), perl = TRUE)
  xi_draws
}))

# Finally, add back the country-dates dropped for >50% missingness as
# empty columns.
combined.posterior <- add_empty_cols(combined.posterior, dropped_dates)

full_names <- union(utable_names, colnames(combined.posterior))
info(sprintf("Found %d expanded country-dates", length(full_names)))

###
# Thin our `xi` posteriors once more for the HLIs: keep N_HLI_DRAWS evenly
# spaced draws to match the dimensions of our z.sample files when
# constructing the HLIs.
N_HLI_DRAWS <- 900
if (nrow(combined.posterior) < N_HLI_DRAWS) {
  stop("Too few draws in posterior object")
}

# Stretch the thinned draws over every country-date in full_names and map
# them onto the probability scale with pnorm.
# NOTE(review): an earlier comment here mentioned cleaning frefair by
# elecreg; that appears to be copied from another index and does not apply.
# drop = FALSE keeps the column (country-date) name even with one column.
idx <- seq(1, by = nrow(combined.posterior) / N_HLI_DRAWS, length.out = N_HLI_DRAWS)
thin.ma <- t(combined.posterior[idx, , drop = FALSE]) %>%
  stretch(full_names) %>%
  pnorm

file.path(OUTDIR, "thin_post") %>% dir.create(recursive = TRUE, showWarnings = FALSE)
file.path(OUTDIR, "thin_post", sprintf("thin_post_%s.rds", SAVE.NAME)) %>%
  write_file(thin.ma, .)

###
# Summarise and generate point estimates at CD & CY-level.
final_cd.df <- dist_summary(combined.posterior, full_names)

# Map every numeric summary column onto the probability scale. List-style
# `[` works for both data.frames and tibbles, including the case of a single
# numeric column (where `df[, b]` on a data.frame would drop to a vector).
# NOTE(review): if dist_summary() returns a spread column (e.g. an SD),
# pnorm() of it is not meaningful -- confirm which columns it produces.
b <- vapply(final_cd.df, is.numeric, logical(1))
final_cd.df[b] <- lapply(final_cd.df[b], pnorm)

file.path(OUTDIR, "results", "cd") %>% dir.create(recursive = TRUE, showWarnings = FALSE)
file.path(OUTDIR, "results", "cd", sprintf("bayes.%s_cd.rds", SAVE.NAME)) %>%
  write_file(final_cd.df, .)

# Aggregate country-date estimates to country-years.
final_cy.df <- cy.day_mean(final_cd.df, historical_date, country_text_id)

file.path(OUTDIR, "results", "cy") %>% dir.create(recursive = TRUE, showWarnings = FALSE)
file.path(OUTDIR, "results", "cy", sprintf("bayes.%s_cy.rds", SAVE.NAME)) %>%
  write_file(final_cy.df, .)

info("Finished!")

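# Extract factor loadings and uniqueness scores from the BFAs.
#
# The code below labels gamma[, 1] as each input variable's intercept and
# gamma[, 2] as its slope (the factor loading); omega is reported squared as
# the uniqueness. NOTE(review): a previous version of this comment claimed
# the opposite order (gamma[1] = slope/loading, gamma[2] = intercept) --
# confirm the column order against bfa.jag.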
# The variable order must be the same as how the BFAs were run.
deps.ll <- list(academic_freedom_index = c("v2cafres", "v2cafexch", "v2cainsaut", "v2casurv"))


# File path of the posterior file for academic_freedom_index
f <- file.path(OUTDIR, "mcmc_posteriors", "post_academic_freedom_index.rds")

# Index name: the file name without directory, extension and "post_" prefix
# (stripping only the basename avoids touching a "post_" inside OUTDIR).
VAR <- sub("^post_", "", file_path_sans_ext(basename(f)))

posteriors <- read_file(f)

# Stack the gamma and omega draws of all runs row-wise.
combined.posterior <- lapply(posteriors, function(o)  {
  m <- as.matrix(o)
  m[, grepl("gamma|omega", colnames(m)), drop = FALSE]
}) %>% do.call(rbind, .)

deps <- deps.ll[[VAR]]
intercept_cols <- grepl("gamma\\[\\d+\\,1\\]", colnames(combined.posterior))
slope_cols <- grepl("gamma\\[\\d+\\,2\\]", colnames(combined.posterior))
uniqueness_cols <- grepl("omega", colnames(combined.posterior))

# Names are assigned positionally; a count mismatch would silently recycle
# (or misalign) the variable labels, so check it explicitly.
stopifnot(!is.null(deps),
          sum(intercept_cols) == length(deps),
          sum(slope_cols) == length(deps),
          sum(uniqueness_cols) == length(deps))

colnames(combined.posterior)[intercept_cols] <- paste0(deps, "_intercept")
colnames(combined.posterior)[slope_cols] <- paste0(deps, "_slope")
colnames(combined.posterior)[uniqueness_cols] <- paste0(deps, "_uniqueness")

combined.posterior %<>% as.data.frame(stringsAsFactors = FALSE)

# Uniqueness is reported as omega squared.
# NOTE(review): presumably omega is a residual SD, making this a variance --
# confirm against bfa.jag.
fu <- function(x) x^2

# Posterior median of every parameter; uniqueness columns are then squared.
res <- combined.posterior %>%
  summarise(across(everything(), median)) %>%
  mutate(across(matches("_uniqueness"), fu))

write_file(res,
           file.path(OUTDIR, "factors_" %^% VAR %^% ".csv"),
           row.names = FALSE)

# One row per input variable with its loading (slope) and uniqueness.
# Splitting the "<variable>_<statistic>" names keeps both statistics aligned
# by variable and avoids the duplicated "Variable"/"Values" column names that
# cbind()-ing two long tables produced.
res_table <- res %>%
  dplyr::select(-c(ends_with("intercept"))) %>%
  pivot_longer(cols = ends_with(c("slope", "uniqueness")),
               names_to = c("Variable", ".value"),
               names_pattern = "^(.*)_(slope|uniqueness)$") %>%
  dplyr::rename(Loading = slope, Uniqueness = uniqueness) %>%
  as.data.frame()  # stargazer expects a plain data.frame, not a tibble

stargazer::stargazer(res_table, summary = FALSE)

# datasummary_df() writes into table/; make sure it exists.
dir.create("table", showWarnings = FALSE, recursive = TRUE)
modelsummary::datasummary_df(res_table, output = "table/Table_F1.tex")
